from .utils import registry
from pydantic import conint, constr
from lightning.pytorch.callbacks import ModelCheckpoint, RichModelSummary, RichProgressBar, LearningRateMonitor, EarlyStopping
import time
import torch
from lightning import LightningModule, Trainer
from lightning.pytorch import Callback
from lightning.pytorch.utilities import rank_zero_info


class CUDAMetricsCallback(Callback):
    """Log per-epoch wall-clock time and peak CUDA memory during training.

    At the start of every training epoch the peak-memory counter of the
    trainer's root CUDA device is reset and a timer is started. At the end of
    the epoch both measurements are reduced across processes with
    ``trainer.strategy.reduce`` and logged on rank zero only.

    The trainer's root device is expected to be a CUDA device; the
    ``torch.cuda`` calls below are made unconditionally.
    """

    def on_train_epoch_start(self, trainer: "Trainer", pl_module: "LightningModule") -> None:
        device = self.root_gpu(trainer)
        # Reset the peak-memory counter so the value read at epoch end
        # reflects this epoch only.
        torch.cuda.reset_peak_memory_stats(device)
        # Wait for already-queued kernels so they are not billed to this epoch.
        torch.cuda.synchronize(device)
        self.start_time = time.time()

    def on_train_epoch_end(self, trainer: "Trainer", pl_module: "LightningModule") -> None:
        device = self.root_gpu(trainer)
        # CUDA work is asynchronous: wait for it before reading time/memory.
        torch.cuda.synchronize(device)
        max_memory = torch.cuda.max_memory_allocated(device) / 2**20  # bytes -> MiB
        epoch_time = time.time() - self.start_time

        # Strategy.reduce only all-reduces Tensor inputs; plain Python floats
        # are returned untouched, so wrap the values in tensors on the root
        # device for the cross-process averaging to actually take place.
        root_device = trainer.strategy.root_device
        max_memory = trainer.strategy.reduce(torch.tensor(max_memory, device=root_device)).item()
        epoch_time = trainer.strategy.reduce(torch.tensor(epoch_time, device=root_device)).item()

        rank_zero_info(f"Average Epoch time: {epoch_time:.2f} seconds")
        rank_zero_info(f"Average Peak memory {max_memory:.2f}MiB")

    def root_gpu(self, trainer: "Trainer") -> int:
        """Return the CUDA index of the trainer's root device (None for non-indexed devices)."""
        return trainer.strategy.root_device.index
    
@registry.callbacks('cuda_metric')
def build_cuda_metric():
    """Create the callback that logs epoch time and peak CUDA memory."""
    metrics_callback = CUDAMetricsCallback()
    return metrics_callback

@registry.callbacks('model_checkpoint')
def build_model_checkpoint(
    dirpath: str,
    monitor: str,
    filename: str = '{epoch}-{step}-{val/loss}',
    mode: constr(regex=r'^(min|max)$') = 'min',
    save_top_k: int = 1,
    auto_insert_metric_name: bool = True,
):
    """Build a Lightning ``ModelCheckpoint`` callback.

    Args:
        dirpath: Directory the checkpoint files are written to.
        monitor: Name of the logged metric used to rank checkpoints.
        filename: Checkpoint filename template. When
            ``auto_insert_metric_name`` is True, Lightning expands ``{name}``
            into ``name=<value>``. A metric name containing ``/`` (like the
            default ``{val/loss}``) therefore adds an extra directory level to
            the checkpoint path. To avoid it, pass
            ``auto_insert_metric_name=False`` together with an explicit
            template such as ``'epoch={epoch}-val_loss={val/loss:.4f}'``.
        mode: ``'min'`` or ``'max'``: whether a lower or higher ``monitor``
            value is better. The anchored pattern rejects near-misses such as
            ``'minimum'``.
        save_top_k: Number of best checkpoints to keep (Lightning's default is 1).
        auto_insert_metric_name: Forwarded unchanged to ``ModelCheckpoint``.
    """
    return ModelCheckpoint(
        monitor=monitor,
        dirpath=dirpath,
        filename=filename,
        mode=mode,
        save_top_k=save_top_k,
        auto_insert_metric_name=auto_insert_metric_name,
    )

@registry.callbacks('model_summary')
def build_model_summary(max_depth: int):
    """Create a rich model-summary callback.

    Args:
        max_depth: Nesting depth of submodules shown in the summary; forwarded
            to ``RichModelSummary``.
    """
    summary_callback = RichModelSummary(max_depth=max_depth)
    return summary_callback

@registry.callbacks('rich_progress_bar')
def build_progress_bar():
    """Create a ``RichProgressBar`` with Lightning's default settings."""
    progress_bar = RichProgressBar()
    return progress_bar

@registry.callbacks('lr_monitor')
def build_lr_monitor():
    """Create a ``LearningRateMonitor`` with Lightning's default settings."""
    lr_callback = LearningRateMonitor()
    return lr_callback

@registry.callbacks('early_stop')
def build_early_stop(
    monitor: str = 'val/loss',
    patience: conint(ge=1) = 3,
    # pydantic v1 validates constr(regex=...) with re.match, which anchors
    # only at the start of the string; the ^...$ anchors make the whole value
    # match, so e.g. 'minimum' is rejected at config-validation time.
    mode: constr(regex=r'^(min|max)$') = 'min',
    min_delta: float = 0,
):
    """Build a Lightning ``EarlyStopping`` callback.

    Args:
        monitor: Name of the logged metric to watch.
        patience: Number of checks with no improvement before stopping (>= 1).
        mode: ``'min'`` or ``'max'``: whether a lower or higher ``monitor``
            value counts as an improvement.
        min_delta: Minimum change in ``monitor`` that counts as an improvement.
    """
    return EarlyStopping(monitor=monitor, patience=patience, mode=mode, min_delta=min_delta)